from random import *

# A simple, fixed gridworld. The goal is to dig up gold and deposit it at the depot.
# States are 0-15, with 0 being the bottom left, 3 being the bottom right, 15 being the top right.
# Actions: move in the 4 directions (0 = left,1 = right,2 = up,3 = down), and dig (4).
# RM state:
# 0 = haven't acquired gold
# 1 = acquired gold, but haven't deposited
# 2 = deposited gold

# Actual grid, drawn with the top row first (S = agent start, G = gold, D = depot):
# [D, 0, 0, G]
# [0, 0, 0, G]
# [0, 0, 0, G]
# [S, 0, 0, G]

# has_gold[pos] is True for the cells that contain gold (the right-hand column: 3, 7, 11, 15).
has_gold = [pos in (3, 7, 11, 15) for pos in range(16)]

# Cell where gold is deposited; entering it ends the episode.
depot = 12

# Per-cell gold-detector error models. The belief-based agents add
# `has_gold_errors[pos] * error_eps` to the true gold indicator (0 or 1) to obtain the
# detector's gold probability for that cell.

# Errors are negative on every gold cell (gold is under-reported) and non-negative elsewhere.
has_gold_errors_uniform = [0.21, 0.15, 0.12, -0.13,
                           0.14, 0.21, 0.2, -0.1,
                           0.26, 0.23, 0.15, -0.16,
                           0, 0.17, 0.1, -0.16]

# A single false positive: cell 4 has no gold but is reported as likely gold.
has_gold_errors_fp = [0, 0, 0, 0,
                      0.6, 0, 0, 0,
                      0, 0, 0, 0,
                      0, 0, 0, 0]

# NOTE(review): as written, the -0.6 entries below sit on cells 2, 9 and 12, none of which
# has gold, so they can never lower a gold probability (it is clamped at 0 / never crosses
# the 0.5 threshold there). They were probably meant for the gold cells 3, 7, 11, 15 —
# confirm before enabling.
# has_gold_errors_fn = [0, 0, -0.6, 0,
#                     0, 0, 0, 0,
#                     0, -0.6, 0, 0,
#                     -0.6, 0, 0, 0]

# Active error model.
has_gold_errors = has_gold_errors_uniform


# Learning hyperparameters shared by all agents.
discount = 0.97
eps = 0.2              # exploration rate for epsilon-greedy action selection
lr = 0.01
max_frames = 10 ** 6   # environment steps per training run

# Default scale of the detector error (0 means no error). The belief-based agents copy this
# global when they are constructed; defining it here lets them be built even when the
# script loop at the bottom of the file has not run.
error_eps = 0

class MiningEnv:
    """The fixed 4x4 mining gridworld described at the top of the file.

    Observations are ``(pos, rm_state)`` pairs. ``pos`` is the cell index 0-15
    (column ``pos % 4``, row ``pos // 4``; moving "up" adds 4). ``rm_state`` is
    the true reward-machine state (0 = no gold, 1 = carrying gold, 2 = deposited).
    """

    def __init__(self):
        self.pos, self.rm_state = 0, 0

    def reset(self):
        """Start a new episode on cell 0 without gold and return the observation."""
        self.pos, self.rm_state = 0, 0
        return (self.pos, self.rm_state)

    def step(self, action):
        """Apply ``action`` (0-3 = move left/right/up/down, 4 = dig).

        Returns ``(observation, reward, done, info)``; ``info`` is always None.
        Every move action costs 0.05, even when the grid edge blocks it. Digging
        on a gold cell while ``rm_state`` is 0 sets it to 1. Entering the depot
        ends the episode; arriving with gold sets ``rm_state`` to 2 and adds +1.
        """
        # Moves are charged up front; digging is free.
        reward = -0.05 if action in (0, 1, 2, 3) else 0

        col, row = self.pos % 4, self.pos // 4
        if action == 0 and col > 0:        # left, unless on the left edge
            self.pos -= 1
        elif action == 1 and col < 3:      # right, unless on the right edge
            self.pos += 1
        elif action == 2 and row < 3:      # up, unless on the top row
            self.pos += 4
        elif action == 3 and row > 0:      # down, unless on the bottom row
            self.pos -= 4
        elif action == 4 and self.rm_state == 0 and has_gold[self.pos]:  # dig
            self.rm_state = 1

        done = self.pos == depot
        if done:
            if self.rm_state == 1:
                self.rm_state = 2
                reward += 1
            else:
                # NOTE(review): this overwrites the move cost, so a gold-less arrival
                # scores exactly 0 while a delivery scores 1 - 0.05; confirm intended.
                reward = 0

        return (self.pos, self.rm_state), reward, done, None

# Q-learning with the perfect reward machine
class QLPerfectRM:
    """Tabular Q-learning baseline that observes the environment's true RM state.

    Q-values live in ``self.q``, keyed by ``(pos, rm_state, action)``.
    """

    def __init__(self):
        self.discount = discount
        self.eps = eps
        self.lr = lr
        self.max_frames = max_frames

        # Entries that are never acted from (RM state 2, or standing on the depot, which
        # ends the episode) are pinned at 0; the rest start uniformly in [-0.5, 0.5).
        self.q = {}
        for pos in range(16):
            for rm_state in range(3):
                for a in range(5):
                    pinned = rm_state == 2 or pos == depot
                    self.q[(pos, rm_state, a)] = 0 if pinned else random() - 0.5
        self.env = MiningEnv()

    def train(self):
        """Run epsilon-greedy Q-learning until at least ``max_frames`` steps have elapsed.

        Steps are added to the frame count when an episode finishes, so the last
        episode may run past the budget.
        """
        eps_num = 0
        frames = 0

        while frames < self.max_frames:
            state = self.env.reset()
            eps_num += 1
            eps_len, returnn = 0, 0
            done = False

            while not done:
                explore = random() < self.eps
                action = randint(0, 4) if explore else self.get_best_action(state)

                next_state, reward, done, _ = self.env.step(action)

                # One-step Q-learning update toward reward + discount * max_a' Q(s', a').
                key = (*state, action)
                target = reward + self.discount * self.get_state_value(next_state)
                self.q[key] += self.lr * (target - self.q[key])

                state = next_state
                eps_len += 1
                returnn += reward * self.discount ** (eps_len - 1)

            # Debug hook: eps_num, returnn and eps_len describe the finished episode.
            frames += eps_len

    def get_state_value(self, state):
        """Return max over the 5 actions of Q(state, a), floored at -10000000000."""
        floor = -10000000000
        return max([floor] + [self.q[(*state, a)] for a in range(5)])

    def get_best_action(self, state):
        """Return the greedy action for ``state``; ties go to the lowest action index."""
        action_values = [self.q[(*state, a)] for a in range(5)]
        best_action, best_q = -1, -10000000000
        for a, qsa in enumerate(action_values):
            if qsa > best_q:
                best_action, best_q = a, qsa
        return best_action

    # Visualization
    def print_q(self):
        """Print every Q-table entry, rounded to 2 decimals."""
        for key, value in self.q.items():
            print(key, round(value, 2))

    def eval_episode(self):
        """Run one greedy episode (capped at 100 steps) and print its discounted return."""
        state = self.env.reset()
        eps_len, returnn = 0, 0
        done = False

        while not done and eps_len < 100:
            action = self.get_best_action(state)
            state, reward, done, _ = self.env.step(action)
            eps_len += 1
            returnn += reward * self.discount ** (eps_len - 1)

        print("Episode done. Return is %.2f." %(returnn))



# Q-learning with belief thresholding. `has_gold_errors` shows the error model, and `error_eps` (0 to infinity) controls the degree of error (0 means no error)
class QLBeliefThresholding:
    """Q-learning over (position, predicted RM state) using a thresholded gold detector.

    The agent never reads the environment's true ``rm_state``. Instead it keeps its own
    prediction ``rm_state_pred``. A dig counts as acquiring gold when the noisy detector
    (true gold indicator plus ``has_gold_errors[pos] * error_eps``) exceeds 0.5. Reaching
    the depot while predicting gold advances the prediction to 2.
    """

    def __init__(self):
        self.discount = discount
        self.eps = eps
        self.lr = lr
        self.max_frames = max_frames
        # Snapshot of the module-level `error_eps` at construction time.
        self.error_eps = error_eps

        # Depot entries start at 0: entering the depot ends the episode, so they are only
        # ever read as bootstrap targets, never updated.
        self.q = {
            (pos, rm_state, a): 0 if pos == depot else (random() - 0.5)
            for pos in range(16) for rm_state in range(3) for a in range(5)
        }
        self.env = MiningEnv()

    def _predict_next_rm_state(self, rm_state_pred, action, next_pos):
        """Advance the thresholded RM-state prediction after one environment step.

        Shared by `train` and `eval_episode` so both use the same detector rule.
        """
        if rm_state_pred == 0 and action == 4:
            error = has_gold_errors[next_pos] * self.error_eps
            if has_gold[next_pos]:
                # Real gold is detected unless the error drags 1 + error to <= 0.5.
                return 1 if error > -0.5 else 0
            # No gold here: a false positive only if the error alone exceeds 0.5.
            return 1 if error > 0.5 else 0
        if rm_state_pred == 1 and next_pos == depot:
            return 2
        return rm_state_pred

    def train(self):
        """Run epsilon-greedy Q-learning until at least ``max_frames`` steps have elapsed."""
        eps_num = 0
        frames = 0

        while frames < self.max_frames:
            state = self.env.reset()
            rm_state_pred = 0
            eps_num += 1
            eps_len = 0
            returnn = 0

            while True:
                if random() < self.eps:
                    action = randint(0, 4)
                else:
                    action = self.get_best_action((state[0], rm_state_pred))

                next_state, reward, done, _ = self.env.step(action)
                next_rm_state_pred = self._predict_next_rm_state(rm_state_pred, action, next_state[0])

                # Update Q values on the agent's view: (position, predicted RM state).
                key = (state[0], rm_state_pred, action)
                target = reward + self.discount * self.get_state_value((next_state[0], next_rm_state_pred))
                self.q[key] += self.lr * (target - self.q[key])

                state = next_state
                rm_state_pred = next_rm_state_pred
                eps_len += 1
                returnn += reward * self.discount ** (eps_len - 1)

                if done:
                    #print("Episode %d done. Return is %.2f. Episode length is %d" %(eps_num, returnn, eps_len))
                    frames += eps_len
                    break

    def get_state_value(self, state):
        """Return max over the 5 actions of Q(state, a), floored at -10000000000."""
        best_value = -10000000000

        for a in range(5):
            best_value = max(best_value, self.q[(*state, a)])

        return best_value

    def get_best_action(self, state):
        """Return the greedy action for ``state``; ties go to the lowest action index."""
        best_action = -1
        best_q = -10000000000

        for a in range(5):
            qsa = self.q[(*state, a)]
            if qsa > best_q:
                best_action = a
                best_q = qsa

        return best_action

    # Visualization
    def print_q(self):
        """Print every Q-table entry, rounded to 2 decimals."""
        for k, v in self.q.items():
            print(k, round(v, 2))

    def eval_episode(self):
        """Run one greedy episode (capped at 100 steps) and print its return and length."""
        state = self.env.reset()
        rm_state_pred = 0

        eps_len = 0
        returnn = 0

        while True:
            action = self.get_best_action((state[0], rm_state_pred))
            next_state, reward, done, _ = self.env.step(action)

            rm_state_pred = self._predict_next_rm_state(rm_state_pred, action, next_state[0])
            state = next_state
            eps_len += 1
            returnn += reward * self.discount ** (eps_len - 1)

            if done or eps_len == 100:
                print("Episode done. Return is %.2f. Episode length is %d" %(returnn, eps_len))
                break

# Q-learning with independent belief updating
# The Q-values are conditioned on s,a, and a belief of the reward machine state.
# We make the simplifying assumption that the Q-values are linear in terms of the RM state belief components. 
class QLIndependentBelief:
    """Q-learning conditioned on a belief over the RM state.

    ``Q(pos, belief, a)`` is modeled as ``sum_k belief[k] * q[(pos, k, a)]``. Each TD
    update is distributed across the three table entries in proportion to their belief
    weights.
    """

    def __init__(self, decorrelate=False):
        self.discount = discount
        self.eps = eps
        self.lr = lr
        self.max_frames = max_frames
        # Snapshot of the module-level `error_eps` at construction time.
        self.error_eps = error_eps
        # When True, digging a cell already dug this episode leaves the belief unchanged.
        self.decorrelate = decorrelate

        self.q = {
            (pos, rm_state, a): 0 if rm_state == 2 or pos == depot else (random() - 0.5)
            for pos in range(16) for rm_state in range(3) for a in range(5)
        }
        self.env = MiningEnv()

    def _next_belief(self, rm_belief, action, next_pos, dug):
        """Return the RM-state belief after taking ``action`` and landing on ``next_pos``.

        ``dug`` is the per-cell "already dug this episode" list; it is updated in place
        on a dig. Shared by `train` and `eval_episode` so both apply the same rule.
        """
        if next_pos == depot:
            # Reaching the depot moves the "carrying gold" mass to "deposited".
            return (rm_belief[0], 0, rm_belief[1])
        if action == 4:
            if self.decorrelate and dug[next_pos]:
                # Presumably so that repeated digs of one cell are not counted as
                # independent evidence.
                p1 = rm_belief[1]
            else:
                p1 = rm_belief[1] + rm_belief[0] * self.get_prob_gold(next_pos)
            dug[next_pos] = True
            # rm_belief[2] is always 0 here: mass only reaches index 2 at the depot,
            # which ends the episode.
            return (1 - p1, p1, 0)
        return rm_belief

    def train(self):
        """Run epsilon-greedy Q-learning until at least ``max_frames`` steps have elapsed."""
        eps_num = 0
        frames = 0

        while frames < self.max_frames:
            state = self.env.reset()[0]
            rm_belief = (1, 0, 0)
            dug = [False] * 16

            eps_num += 1
            eps_len = 0
            returnn = 0

            while True:
                if random() < self.eps:
                    action = randint(0, 4)
                else:
                    action = self.get_best_action(state, rm_belief)
                next_state, reward, done, _ = self.env.step(action)
                next_state = next_state[0]

                next_rm_belief = self._next_belief(rm_belief, action, next_state, dug)

                # Update Q values: each per-RM-state entry moves in proportion to its belief weight.
                delta = reward + self.discount * self.get_state_value(next_state, next_rm_belief) - self.get_q_value(state, rm_belief, action)
                for k in range(3):
                    self.q[(state, k, action)] += self.lr * delta * rm_belief[k]

                state = next_state
                rm_belief = next_rm_belief
                eps_len += 1
                returnn += reward * self.discount ** (eps_len - 1)

                if done:
                    #print("Episode %d done. Return is %.2f. Episode length is %d" %(eps_num, returnn, eps_len))
                    frames += eps_len
                    break

    def get_state_value(self, state, rm_belief):
        """Return max over the 5 actions of the belief-weighted Q-value, floored at -10000000000."""
        best_value = -10000000000

        for a in range(5):
            best_value = max(best_value, self.get_q_value(state, rm_belief, a))

        return best_value

    def get_best_action(self, state, rm_belief):
        """Return the greedy action under ``rm_belief``; ties go to the lowest action index."""
        best_action = -1
        best_q = -10000000000

        for a in range(5):
            qsa = self.get_q_value(state, rm_belief, a)
            if qsa > best_q:
                best_action = a
                best_q = qsa

        return best_action

    def get_q_value(self, state, rm_belief, a):
        """Return the belief-weighted Q-value of action ``a`` at position ``state``."""
        return rm_belief[0] * self.q[(state, 0, a)] + rm_belief[1] * self.q[(state, 1, a)] + rm_belief[2] * self.q[(state, 2, a)]

    def get_prob_gold(self, state):
        """Return the detector's gold probability for a cell, clamped to [0, 1]."""
        return min(1, max(0, int(has_gold[state]) + has_gold_errors[state] * self.error_eps))

    # Visualization
    def print_q(self):
        """Print every Q-table entry, rounded to 2 decimals."""
        for k, v in self.q.items():
            print(k, round(v, 2))

    def eval_episode(self):
        """Run one greedy episode (capped at 100 steps) and print its return and length."""
        state = self.env.reset()[0]
        rm_belief = (1, 0, 0)
        dug = [False] * 16

        eps_len = 0
        returnn = 0

        while True:
            action = self.get_best_action(state, rm_belief)
            next_state, reward, done, _ = self.env.step(action)

            #print("State:", state, "RM belief:", rm_belief, "Action:", action)
            next_state = next_state[0]

            rm_belief = self._next_belief(rm_belief, action, next_state, dug)
            state = next_state
            eps_len += 1
            returnn += reward * self.discount ** (eps_len - 1)

            if done or eps_len == 100:
                print("Episode done. Return is %.2f. Episode length is %d" %(returnn, eps_len))
                break

# Train 8 independent agents with the perfect reward machine and print one greedy
# evaluation episode for each. `error_eps` is only read by the belief-based agents
# (QLBeliefThresholding, QLIndependentBelief); 0 means no detector error.
error_eps = 0
for i in range(8):
    algo = QLPerfectRM()
    algo.train()
    algo.eval_episode()
# algo.print_q()  # uncomment to dump the last agent's Q-table


